# ---- Session setup ----
# Work from the project directory; the relative .rds paths used further down
# in this script are resolved against the working directory.
setwd("/data/home/lyang/Visium_RCTD")

# Load the HDF5 high-level shared library explicitly before any package is
# attached (presumably required by the h5Seurat reader on this host -- TODO
# confirm).
dyn.load("/data/home/bioinfo/programs/hdf5-1.12.0/hdf5/lib/libhdf5_hl.so.200")
# Package search path: the personal library comes first, followed by the
# shared and system libraries.
.libPaths(c("/data/home/lyang/R/x86_64-redhat-linux-gnu-library/4.1",
            "/data/home/bioinfo/R/R4.1.0/library_B3.13",
            "/usr/lib64/R/library","/usr/share/R/library") )
# Make knitr evaluate chunks relative to the same project directory.
knitr::opts_knit$set(root.dir = "/data/home/lyang/Visium_RCTD")
# NOTE(review): this repeats the setwd() call at the top of the file.
setwd("/data/home/lyang/Visium_RCTD")
# Attach every package the analysis relies on, in a fixed order.
#
# Called purely for its side effect on the search path. The (invisible) value
# is whatever the last library() call returns.
init_require_packages <- function() {
  # "Matrix" used to be attached here as well; it is intentionally left out.
  required_packages <- c(
    "spacexr", "Seurat", "readxl", "SeuratData", "stringr",
    "patchwork", "dplyr", "ggplot2", "SeuratDisk", "doParallel"
  )
  attached <- NULL
  for (pkg in required_packages) {
    attached <- library(pkg, character.only = TRUE)
  }
  invisible(attached)
}

# Attach all analysis packages before anything else runs.
init_require_packages()


# Name of the single-cell reference dataset; it selects the metadata columns
# used below and prefixes the cache file names.
dataset_name = "cell2location"
# dataset_name = "Tabula" "cell2location"
sc_dataset_file = paste(dataset_name,"sc_dataset.rds",sep ="_")



# Reuse the cached Seurat object when the .rds file exists in the working
# directory; otherwise import it from the h5Seurat file and cache it.
if(file.exists(sc_dataset_file)){
  rds_file= sc_dataset_file
  sc_dataset <- readRDS(rds_file)
  }else{
  
  # Load the RNA assay from the h5Seurat file, write it to `sc_dataset_file`
  # and return the loaded object.
  # NOTE(review): the input path is a Windows drive path while the rest of the
  # script uses /data/home/... -- confirm where the import is meant to run.
  import_sc_dataset <- function()
  {
    # Convert("D:/minor_intership_data/sc.h5ad", dest = "h5seurat", overwrite = TRUE)
    sc_dataset <- LoadH5Seurat("D:/minor_intership_data/sc.h5seurat",assays  = "RNA")
    # Bare expression: has no effect inside a function (nothing is printed).
    sc_dataset
    
    saveRDS(sc_dataset,sc_dataset_file)
    return(sc_dataset)
  }
  sc_dataset = import_sc_dataset()
  sc_dataset
  # List the metadata columns that came with the imported object.
  colnames(x = sc_dataset[[]])
    
}
# Harmonise the UMI-count column, add a mitochondrial fraction and filter
# cells on QC metrics.
#
# Args:
#   sc_dataset:   Seurat object with an "RNA" assay and an `n_genes` metadata
#                 column.
#   dataset_name: "cell2location" (UMI totals copied from `n_counts`) or
#                 "Tabula" (copied from `n_counts_UMIs`). For any other value
#                 `nCount_RNA` is left untouched.
#
# Returns:
#   A list with `sc_dataset` (annotated, unfiltered) and
#   `sc_dataset_filtered` (cells with 200 < n_genes < 5000 and
#   percent.mt < 0.05).
qc_filter_sc <- function(sc_dataset, dataset_name) {
  if (dataset_name == "cell2location") {
    sc_dataset[["nCount_RNA"]] <- sc_dataset$n_counts
  } else if (dataset_name == "Tabula") {
    sc_dataset[["nCount_RNA"]] <- sc_dataset$n_counts_UMIs
  }

  # Mitochondrial share per cell, stored as a fraction in [0, 1] (not as a
  # percentage), which is why the cut-off below is 0.05.
  # Alternative kept for reference:
  # sc_dataset[["percent.mt"]] <- PercentageFeatureSet(sc_dataset, pattern = "^MT-")
  rna_assay <- sc_dataset@assays$RNA
  mito_genes <- grep(pattern = "^MT-", x = rownames(x = rna_assay), value = TRUE)
  mito_total <- Matrix::colSums(rna_assay[mito_genes, ])
  cell_total <- Matrix::colSums(rna_assay)
  sc_dataset <- AddMetaData(
    object = sc_dataset,
    metadata = mito_total / cell_total,
    col.name = "percent.mt"
  )

  # Keep only the cells that pass every QC threshold.
  passing_cells <- subset(
    sc_dataset,
    subset = n_genes > 200 & n_genes < 5000 & percent.mt < 0.05
  )

  list("sc_dataset" = sc_dataset, "sc_dataset_filtered" = passing_cells)
}
# Cells are kept only when 200 < n_genes < 5000 (see qc_filter_sc()).
# Cells whose mitochondrial fraction is 5% or higher are removed as well.
# These QC criteria are adapted from the Seurat tutorial.
# NOTE(review): an earlier version of this note quoted an upper bound of 2,500
# genes and "9532/53275 cells retained"; neither matches the current code or
# the object printed below -- confirm which numbers are current.
a_list = qc_filter_sc(sc_dataset,dataset_name)
sc_dataset = a_list$sc_dataset
sc_dataset_filtered = a_list$sc_dataset_filtered

sc_dataset_filtered
## An object of class Seurat 
## 10237 features across 72566 samples within 1 assay 
## Active assay: RNA (10237 features, 0 variable features)
##  2 dimensional reductions calculated: pca, umap
# sc_dataset_filtered$CellType
# Visualize QC metrics of the unfiltered object as violin plots.
VlnPlot(sc_dataset, features = c("n_genes", "nCount_RNA"), ncol = 2)

# Scrutinize which cell types are lost during QC filtering: compare the
# cell-type composition before and after filtering to see whether it changed.

# dataset_name = "cell2location34" "Tabula"
dataset_name <- "cell2location34"

# Metadata column that holds the cell-type label for the active dataset.
# Fail loudly on an unknown dataset instead of silently reusing a stale table.
cell_type_column <- switch(dataset_name,
  cell2location = "CellType",
  Tabula = "cell_ontology_class",
  cell2location34 = "Subset",
  stop("Unknown dataset_name: ", dataset_name, call. = FALSE)
)

counts_before <- as.data.frame(table(sc_dataset[[cell_type_column]]))
counts_after <- as.data.frame(table(sc_dataset_filtered[[cell_type_column]]))
colnames(counts_before) <- c("cell_type", "before")
colnames(counts_after) <- c("cell_type", "after")

# Attach the post-filter counts by cell-type NAME rather than by position: a
# type that vanished entirely after filtering is missing from `counts_after`,
# which would otherwise shift (or fail to fit) the "after" column.
after_row <- match(
  as.character(counts_before$cell_type),
  as.character(counts_after$cell_type)
)
cell_type_table <- counts_before
cell_type_table$after <- counts_after$after[after_row]
cell_type_table$after[is.na(cell_type_table$after)] <- 0

# Turn both count columns into proportions that sum to 1.
cell_type_table[, -1] <- lapply(cell_type_table[, -1], function(x) x / sum(x))

library(reshape2)
# Long format: one row per (cell type, before/after state).
cell_type_table <- melt(cell_type_table,
  id.vars = "cell_type",
  variable.name = "state", value.name = "proportion"
)

ggplot(
  data = cell_type_table,
  mapping = aes(
    x = cell_type,
    y = proportion,
    colour = state, group = state
  )
) +
  geom_point() +
  geom_line(aes(color = state)) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5))

Quick data exploration:

# annotation = "cell2location34" Tabula
# annotation = "cell2location34"
dataset_name <- "cell2location34"

# Cell-type labels of the unfiltered reference for the active dataset.
# "cell2location" is accepted next to the legacy key "cell2location44" because
# the rest of the script uses "cell2location" for the CellType annotation.
cell_type_labels <- switch(dataset_name,
  cell2location34 = sc_dataset$Subset,
  Tabula = sc_dataset$cell_ontology_class,
  cell2location = ,
  cell2location44 = sc_dataset$CellType,
  stop("Unknown dataset_name: ", dataset_name, call. = FALSE)
)
cell_type_table <- as.data.frame(table(cell_type_labels))

colnames(cell_type_table) <- c("cell_type", "frequency")
library(ggplot2)
# Bar plot of the number of cells per cell type, with the count printed above
# each bar.
p <- ggplot(data = cell_type_table, aes(x = cell_type, y = frequency, fill = cell_type)) +
  geom_bar(stat = "identity") +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5)) +
  geom_text(aes(label = frequency), position = position_dodge(width = 0.9), vjust = -0.25)


# ggsave(p,filename = "cell_type_histogram.pdf",width = 15,
#        height = 20)

print(p)

# Remove every cell that belongs to a rare cell type.
#
# The cell-type annotation is chosen from the global `dataset_name`
# ("cell2location" -> CellType, "Tabula" -> cell_ontology_class,
# "cell2location34" -> Subset) and installed as the active identity.
#
# Args:
#   sc_dataset_filtered: Seurat object (already QC-filtered).
#   min_cells:           A cell type is kept only when it has MORE than this
#                        many cells (default 25, the value used so far).
#
# Returns:
#   The Seurat object without the cells of rare types; identities are set to
#   the cell-type annotation.
delete_rare_cells <- function(sc_dataset_filtered, min_cells = 25) {
  ident_column <- switch(dataset_name,
    cell2location = "CellType",
    Tabula = "cell_ontology_class",
    cell2location34 = "Subset",
    stop("Unknown dataset_name: ", dataset_name, call. = FALSE)
  )
  Idents(sc_dataset_filtered) <- ident_column

  cell_idents <- Idents(sc_dataset_filtered)
  type_sizes <- table(cell_idents)
  rare_cell_types <- names(type_sizes)[as.vector(type_sizes) <= min_cells]

  # Nothing to remove: hand the object back untouched.
  if (length(rare_cell_types) == 0) {
    return(sc_dataset_filtered)
  }

  # Drop all rare types in ONE subset call instead of re-subsetting the whole
  # object once per rare type.
  keep_cells <- names(cell_idents)[!(cell_idents %in% rare_cell_types)]
  subset(sc_dataset_filtered, cells = keep_cells)
}
sc_dataset_filtered <- delete_rare_cells(sc_dataset_filtered)
sc_dataset_filtered
## An object of class Seurat 
## 10237 features across 72548 samples within 1 assay 
## Active assay: RNA (10237 features, 0 variable features)
##  2 dimensional reductions calculated: pca, umap

# Cache file for the down-sampled reference of the active dataset.
sc_dataset.downsampled_file <- paste(dataset_name, "sc_dataset.downsampled.rds", sep = "_")

if (!file.exists(sc_dataset.downsampled_file)) {
  # Cap every cell type at 3000 cells.
  #
  # The identity column is picked from the global `dataset_name`; for an
  # unrecognised name the identities already set on the object are used.
  # Returns the Seurat object restricted to the selected cells.
  subsample_cells <- function(sc_dataset) {
    label_column <- if (dataset_name == "cell2location") {
      "CellType"
    } else if (dataset_name == "Tabula") {
      "cell_ontology_class"
    } else if (dataset_name == "cell2location34") {
      "Subset"
    } else {
      NULL
    }
    if (!is.null(label_column)) {
      Idents(sc_dataset) <- label_column
    }

    selected_cells <- WhichCells(
      sc_dataset,
      idents = unique(Idents(sc_dataset)),
      downsample = 3000
    )
    sc_dataset[, selected_cells]
  }

  # Subsample a fixed maximum number of cells from differently sized cell
  # types, then cache the result for later runs.
  sc_dataset_filtered.downsampled <- subsample_cells(sc_dataset_filtered)
  saveRDS(sc_dataset_filtered.downsampled, sc_dataset.downsampled_file)
} else {
  rds_file <- sc_dataset.downsampled_file
  sc_dataset_filtered.downsampled <- readRDS(rds_file)
}
# Build the spacexr `Reference` object from a Seurat single-cell reference.
#
# The cell-type column is chosen from the global `dataset_name`
# ("cell2location" -> CellType, "Tabula" -> cell_ontology_class,
# "cell2location34" -> Subset); UMI totals are read from `nCount_RNA`.
#
# Args:
#   sc_dataset_filtered: Seurat object with raw counts in the RNA assay.
#
# Returns:
#   A spacexr Reference object. The dimensions of its count matrix
#   (genes x cells) are printed as a side effect.
Create_Reference_object <- function(sc_dataset_filtered) {
  label_column <- switch(dataset_name,
    cell2location = "CellType",
    Tabula = "cell_ontology_class",
    cell2location34 = "Subset",
    stop("Unknown dataset_name: ", dataset_name, call. = FALSE)
  )

  # Hand the counts over in their stored (sparse) form: densifying a
  # genes x cells matrix of this size is very expensive, and Reference()
  # accepts a dgCMatrix directly.
  counts <- sc_dataset_filtered@assays$RNA@counts

  meta_data <- sc_dataset_filtered@meta.data[, c(label_column, "nCount_RNA")]

  # Named factor of cell types (names = cell barcodes). "/" is replaced with
  # "_" in the type labels.
  cell_types <- str_replace_all(meta_data[[label_column]], "/", "_")
  names(cell_types) <- rownames(meta_data)
  cell_types <- droplevels(as.factor(cell_types))

  # Named vector of total UMI counts per cell.
  nUMI <- meta_data$nCount_RNA
  names(nUMI) <- rownames(meta_data)

  reference <- Reference(counts, cell_types, nUMI)

  # Report gene number and cell number of the reference.
  print(dim(reference@counts))
  reference
}

reference <- Create_Reference_object(sc_dataset_filtered.downsampled)
## Warning in if (class(counts) != "dgCMatrix") {: the condition has length > 1 and
## only the first element will be used
## Warning in if (class(counts) != "matrix") tryCatch({: the condition has length >
## 1 and only the first element will be used
## [1] 10237 47443

# Cache file for the SpatialRNA ("puck") object of the spatial dataset.
spatial_dataset_name <- "Visium"
puck_file <- paste(spatial_dataset_name, "puck.rds", sep = "_")

if (!file.exists(puck_file)) {
  # Build the SpatialRNA object from the saved Visium Seurat object, cache it
  # in `puck_file` and return it.
  creat_spatial_RCTD <- function() {
    visium <- readRDS("Visium_spatial.rds")

    # Unscaled tissue coordinates, see
    # https://github.com/dmcable/RCTD/issues/26
    # coords <- spatial@images$slice1@image
    spot_coords <- GetTissueCoordinates(visium, scale = NULL)
    # spatialcounts <- as.matrix(GetAssayData(spatial, assay = "Spatial", slot = "counts"))
    spot_counts <- as.matrix(visium@assays$Spatial@counts)

    spatial_rna <- SpatialRNA(spot_coords, spot_counts)
    saveRDS(spatial_rna, puck_file)
    spatial_rna
  }

  puck <- creat_spatial_RCTD()
} else {
  rds_file <- puck_file
  puck <- readRDS(rds_file)
}
library(quadprog)

# Result cache for the full-mode RCTD run of the active reference dataset.
RCTD_file <- paste(dataset_name, "full_RCTD.rds", sep = "_")
# RCTD_file = paste(dataset_name,"multi_RCTD.rds",sep ="_")
# RCTD_file = paste(dataset_name,"doublet_RCTD.rds",sep ="_")

if (!file.exists(RCTD_file)) {
  # No cached result: configure RCTD, run it in full mode and cache the fit.
  myRCTD <- create.RCTD(puck, reference, max_cores = 18)
  myRCTD <- run.RCTD(myRCTD, doublet_mode = "full")
  # myRCTD <- run.RCTD(myRCTD, doublet_mode = 'multi')
  # myRCTD2 <- run.RCTD(myRCTD, doublet_mode = 'doublet')
  saveRDS(myRCTD, RCTD_file)
} else {
  rds_file <- RCTD_file
  myRCTD <- readRDS(rds_file)
}

Including Plots

You can also embed plots, for example:

# Write the standard RCTD result plots and return the normalised weights.
#
# Args:
#   myRCTD: a fitted RCTD object (result of run.RCTD()).
#
# Returns:
#   The cell-type weight matrix normalised so that each spot sums to 1.
#
# Side effects:
#   Creates the directory "Tabula_Visium_RCTD_Plots" under the working
#   directory if it is missing and lets the spacexr plotting helpers write
#   their output there.
show_results <- function(myRCTD) {
  # Normalize the cell type proportions to sum to 1.
  norm_weights <- normalize_weights(myRCTD@results$weights)

  cell_type_names <- myRCTD@cell_type_info$info[[2]] # list of cell type names
  spatialRNA <- myRCTD@spatialRNA

  # NOTE(review): the directory name says "Tabula" for every reference
  # dataset; kept unchanged so existing outputs stay where they are.
  resultsdir <- "Tabula_Visium_RCTD_Plots"
  # Only create the directory when needed; calling dir.create() on an existing
  # directory raised a warning on every rerun.
  if (!dir.exists(resultsdir)) {
    dir.create(resultsdir)
  }

  # Confident weights for each cell type, as in full mode.
  plot_weights(cell_type_names, spatialRNA, resultsdir, norm_weights)

  # All (unthresholded) weights for each cell type, as in full mode.
  plot_weights_unthreshold(cell_type_names, spatialRNA, resultsdir, norm_weights)

  # Number of pixels in which each cell type is confidently located.
  plot_cond_occur(cell_type_names, resultsdir, norm_weights, spatialRNA)

  norm_weights
}
norm_weights_object <- show_results(myRCTD)

# Build a SpatialRNA object from the GeoMx WTA lymph-node export workbook.
#
# Coordinates come from the "SegmentProperties" sheet (one row per segment,
# keyed by SegmentDisplayName); the expression matrix comes from the
# "TargetCountMatrix" sheet (first column used as row names).
#
# Returns:
#   A spacexr SpatialRNA object (integer counts are not required because the
#   export is Q3-normalised according to its file name).
creat_spatial_RCTD_GeoMx <- function() {
  # Single definition of the workbook path. The previous literal ended in a
  # stray blank ("...xlsx "), which does not name the actual file.
  geomx_workbook <- "D:/minor_intership_data/GeoMx WTA Human Lymph Node Data_count_results/Export4_NormalizationQ3.xlsx"

  coords <- read_excel(geomx_workbook, sheet = "SegmentProperties")
  coords <- as.data.frame(coords[, c("SegmentDisplayName", "ROICoordinateX", "ROICoordinateY")])
  rownames(coords) <- coords[, "SegmentDisplayName"]
  coords$SegmentDisplayName <- NULL

  counts <- as.data.frame(read_excel(geomx_workbook, sheet = "TargetCountMatrix"))
  # NOTE(review): row names are taken from the first column while the column
  # removed is `TargetName` -- this assumes TargetName is the first column.
  rownames(counts) <- counts[, 1]
  counts$TargetName <- NULL

  SpatialRNA(coords, counts, require_int = FALSE)
}
# puck = creat_spatial_RCTD_GeoMx()
# saveRDS(puck,'spatial_GeoMx.rds')
# rds_file='spatial_GeoMx.rds'
# puck <- readRDS(rds_file)

Note that the echo = FALSE parameter was added to the code chunk to prevent printing of the R code that generated the plot.

# ---- Inputs for the scatter-pie plots ----
library(ggplot2)
library(SPOTlight)
normalized_weights <- normalize_weights(myRCTD@results$weights)
# normalize the cell type proportions to sum to 1.
class(normalized_weights)
## [1] "dgeMatrix"
## attr(,"package")
## [1] "Matrix"
# Plain base matrix copy of the weights (spots x cell types).
norm_weights_matrix = as(normalized_weights,'matrix')
# mat <- res$mat
mat <- norm_weights_matrix
ct <- colnames(mat)
# Suppress small contributions: weights below 0.1 are set to 0 for plotting.
mat[mat < 0.1] <- 0


library(RColorBrewer)


# Number of colours to draw for the active reference dataset.
# NOTE(review): `n` stays undefined for any other dataset_name.
if(dataset_name == "Tabula"){
  n <- 29
}else if(dataset_name == "cell2location"){
  n <- 44
}else if(dataset_name == "cell2location34"){
  n <- 34
}


# Pool all colours of the colour-blind-friendly RColorBrewer palettes and draw
# `n` of them at random (no seed is set, so colours differ between runs).
colrs <- brewer.pal.info[brewer.pal.info$colorblind == TRUE, ]
col_vec = unlist(mapply(brewer.pal, colrs$maxcolors, rownames(colrs)))
col <- sample(col_vec, n)
rds_file = "Visium_spatial.rds"
spatial <- readRDS(rds_file)

# Find spots of the Visium object that have no row in the RCTD weight matrix.
test = spatial@assays$Spatial@counts@Dimnames[[2]] %in% rownames(norm_weights_matrix)
which(test == FALSE)
## [1]  464 1343
spatial@assays$Spatial@counts@Dimnames[[2]][464]
## [1] "ACTGTAGCACTTTGGA-1"
spatial@assays$Spatial@counts@Dimnames[[2]][1343]
## [1] "CCCGCCATGCTCCCGT-1"
# Drop those two spots (barcodes hard-coded from the output above) so the
# Seurat object and the weight matrix cover the same spots.
spatial@meta.data$spot_name  <- rownames(spatial@meta.data)
spatial = subset(spatial,spot_name !="ACTGTAGCACTTTGGA-1" )
spatial = subset(spatial,spot_name !="CCCGCCATGCTCCCGT-1")


# Define color palette
# (the variable keeps the name 'paletteMartin', but the colours are the random
# RColorBrewer sample drawn above)
paletteMartin <- col

# Interpolate the sampled colours to one colour per cell type, named by type.
pal <- colorRampPalette(paletteMartin)(length(ct))
names(pal) <- ct


# Untouched copy of the full palette; the highlight helpers start from it.
pal_back <- pal

# Recolour a named palette so that only three marker cell types stand out.
#
# "T_CD4+_naive", "FDC" and "B_naive" each get a fixed highlight colour; every
# other named entry is painted dark blue ("#00008b").
#
# Args:
#   pal: named character vector of colours (names = cell types).
#
# Returns:
#   The recoloured palette, same names and order as `pal`.
plot_3_region <- function(pal) {
  region_colours <- c(
    "T_CD4+_naive" = "#FFFF00",
    "FDC" = "#add8e6",
    "B_naive" = "#FF0000"
  )
  background_colour <- "#00008b"

  for (cell_type in names(pal)) {
    pal[cell_type] <- if (cell_type %in% names(region_colours)) {
      region_colours[[cell_type]]
    } else {
      background_colour
    }
  }

  pal
}
packageVersion("SPOTlight")
## [1] '0.99.4'

# Highlight T_CD4+_naive, FDC and B_naive; all other types share one colour.
pal <- plot_3_region(pal_back)

# Scatter-pie of the thresholded RCTD weights per spot.
# `cell_types` is taken from `mat`, the matrix that is actually plotted; the
# previous `colnames(y)` referred to an object `y` that this script never
# defines.
plotSpatialScatterpie(
  x = spatial,
  y = mat,
  cell_types = colnames(mat),
  img = FALSE,
  scatterpie_alpha = 1,
  pie_scale = 0.4
) +
  scale_fill_manual(
    values = pal,
    breaks = names(pal)
  ) +
  scale_y_reverse()

# Grey out (dark blue, "#00008b") every palette entry that is not one of the
# cell types of interest.
#
# Args:
#   pal:              named character vector of colours (names = cell types).
#   col_vec, n:       accepted for call compatibility; not used.
#   cell_type_vector: names of the cell types that keep their own colour.
#
# Returns:
#   The recoloured palette, same names and order as `pal`.
plot_n_cell_type <- function(pal, col_vec, n, cell_type_vector) {
  palette_names <- names(pal)
  background_types <- palette_names[!(palette_names %in% cell_type_vector)]

  for (cell_type in background_types) {
    pal[cell_type] <- "#00008b"
  }

  pal
}

# The six cell types expected inside germinal centres (GC).
cell_type_vector <- c(
  "B_GC_DZ", "B_GC_LZ", "B_GC_prePB", "T_CD4+_TfH_GC",
  "B_Cycling", "FDC"
)

# Scatter-pie with only the GC cell types coloured individually. Kept in `p1`
# because it is printed again later next to the manual GC annotation.
# `cell_types` is taken from `mat`; the previous `colnames(y)` referred to an
# object `y` that this script never defines.
pal <- plot_n_cell_type(pal_back, col_vec, 6, cell_type_vector)
p1 <- plotSpatialScatterpie(
  x = spatial,
  y = mat,
  cell_types = colnames(mat),
  img = FALSE,
  scatterpie_alpha = 1,
  pie_scale = 0.4
) +
  scale_fill_manual(
    values = pal,
    breaks = names(pal)
  ) +
  scale_y_reverse()

# Same plot with the full palette: every cell type keeps its own colour.
pal <- pal_back
# type(pal)
plotSpatialScatterpie(
  x = spatial,
  y = mat,
  cell_types = colnames(mat),
  img = FALSE,
  scatterpie_alpha = 1,
  pie_scale = 0.4
) +
  scale_fill_manual(
    values = pal,
    breaks = names(pal)
  ) +
  scale_y_reverse()

# Load a saved assay object to use as a container for the RCTD weights.
rds_file <- "predictions.assay.rds"
predictions.assay <- readRDS(rds_file)

# Normalize the cell type proportions to sum to 1.
results <- myRCTD@results
norm_weights <- normalize_weights(results$weights)

# Store the weights as assay data with cell types in rows and spots in
# columns; the feature metadata is an empty table keyed by cell-type name.
predictions.assay@data <- t(as.matrix(norm_weights))
meta.features <- data.frame(row.names = colnames(as.matrix(norm_weights)))
predictions.assay@meta.features <- meta.features

# Register the assay on the Visium object and make it the default so the
# feature plots below read the predicted proportions.
spatial[["predictions"]] <- predictions.assay
DefaultAssay(spatial) <- "predictions"

colnames(mat)
##  [1] "B_activated"      "B_Cycling"        "B_GC_DZ"          "B_GC_LZ"         
##  [5] "B_GC_prePB"       "B_IFN"            "B_mem"            "B_naive"         
##  [9] "B_plasma"         "B_preGC"          "DC_CCR7+"         "DC_cDC1"         
## [13] "DC_cDC2"          "DC_pDC"           "Endo"             "FDC"             
## [17] "ILC"              "Macrophages_M1"   "Macrophages_M2"   "Monocytes"       
## [21] "NK"               "NKT"              "T_CD4+"           "T_CD4+_naive"    
## [25] "T_CD4+_TfH"       "T_CD4+_TfH_GC"    "T_CD8+_CD161+"    "T_CD8+_cytotoxic"
## [29] "T_CD8+_naive"     "T_TfR"            "T_TIM3+"          "T_Treg"          
## [33] "VSMC"
# Spatial maps of the predicted proportions for selected cell types.
# Every plot is printed explicitly: a plot that is merely evaluated inside an
# `if` block is not drawn unless it happens to be the last expression, so some
# of these maps used to be silently dropped.
if (dataset_name == "Tabula") {
  p <- SpatialFeaturePlot(spatial, features = c("stromal cell", "t cell"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)
  print(p)
  p <- SpatialFeaturePlot(spatial, features = c("naive b cell", "b cell"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)
  print(p)
} else if (dataset_name == "cell2location") {
  p <- SpatialFeaturePlot(spatial, features = c("T_Treg", "DC_cDC2"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)
  print(p)
  p <- SpatialFeaturePlot(spatial, features = c("DC_cDC1", "DC_cDC2"), pt.size.factor = 1.6, ncol = 2, crop = TRUE)
  print(p)
} else if (dataset_name == "cell2location34") {
  # detach("package:SpatialExperiment", unload=TRUE)

  p <- SpatialFeaturePlot(spatial, features = c("B_GC_LZ", "T_CD4+_TfH_GC", "B_GC_prePB", "FDC"), pt.size.factor = 1.6, ncol = 4, crop = TRUE) +
    ggtitle("germinal center light zone")
  print(p)
  # + scale_fill_continuous(limits = c(0, 1))

  p <- SpatialFeaturePlot(spatial, features = c("B_Cycling", "B_GC_DZ"), pt.size.factor = 1.6, ncol = 2, crop = TRUE) +
    ggtitle("germinal center dark zone")
  print(p)

  p <- SpatialFeaturePlot(spatial, features = c("B_naive", "B_preGC"), pt.size.factor = 1.6, ncol = 2, crop = TRUE) +
    ggtitle("B follicle + pre GC")
  print(p)
}

# table(sc_dataset@meta.data$Subset)
# Leave out selected cell types (B_GC_DZ, B_GC_LZ, B_GC_prePB, T_CD4+_TfH_GC, B_Cycling and FDC, these six subtypes are expected to be present in annotated GC (positive locations) one by one using ‘subset’ command.
# for (selected_type in c("B_GC_DZ", "B_GC_LZ", "B_GC_prePB","T_CD4+_TfH_GC",
#                  "B_Cycling",  "FDC") )
# {
#   # selected_type = "B_GC_DZ"
#   Idents(sc_dataset_filtered.downsampled) <- 'Subset'
#   sc_reference <- subset(sc_dataset_filtered.downsampled, subset = Subset != selected_type)
#   table(sc_reference@meta.data$Subset)
#   reference = Create_Reference_object(sc_reference)
#   
#   RCTD_file = paste(dataset_name,"full_RCTD", selected_type, ".rds",sep ="_")
#   library(quadprog)
#   myRCTD <- create.RCTD(puck, reference, max_cores = 18)
#   myRCTD <- run.RCTD(myRCTD, doublet_mode = 'full')
#   saveRDS(myRCTD,RCTD_file)
#   print(RCTD_file)
# }
# getwd()
library(ggplot2)
library(SPOTlight)

# Manual germinal-centre (GC) annotation: one row per spot barcode, with a
# `GC` column that is "GC" for annotated spots and empty otherwise.
manual_GC_annotation <- read.csv("/data/home/lyang/scripts_git/manual_GC_annotation.csv")
rds_file <- "Visium_spatial.rds"
spatial <- readRDS(rds_file)
spatial@meta.data$Barcode <- rownames(spatial@meta.data)

# Add the annotation columns by looking each spot up by barcode. merge() must
# not be used on `meta.data`: it replaces the barcode row names with 1..n and
# may drop or reorder spots, which detaches the metadata from the cells.
annotation_columns <- setdiff(colnames(manual_GC_annotation), "Barcode")
annotation_row <- match(spatial@meta.data$Barcode, manual_GC_annotation$Barcode)
spatial@meta.data[annotation_columns] <-
  manual_GC_annotation[annotation_row, annotation_columns, drop = FALSE]

# Spots without a GC label (empty in the file, or absent from it) are "no".
not_annotated <- is.na(spatial@meta.data$GC) | spatial@meta.data$GC == ""
spatial@meta.data$GC[not_annotated] <- "no"

p2 <- SpatialDimPlot(spatial, group.by = "GC")
# p1 + p2
print(p2)

# Scatter-pie of the GC cell types created earlier, for visual comparison.
print(p1)

# ---- Scatter-pies for the leave-one-cell-type-out RCTD runs ----
# For each GC cell type, load the RCTD fit whose reference lacked that type
# and plot the thresholded weights with only the GC cell types coloured.
library(RColorBrewer)

cell_type_vector <- c(
  "B_GC_DZ", "B_GC_LZ", "B_GC_prePB", "T_CD4+_TfH_GC",
  "B_Cycling", "FDC"
)

# Loop-invariant inputs are prepared once instead of on every iteration.
n <- 34
colrs <- brewer.pal.info[brewer.pal.info$colorblind == TRUE, ]
col_vec <- unlist(mapply(brewer.pal, colrs$maxcolors, rownames(colrs)))

# Visium object without the two spots that have no RCTD weights.
rds_file <- "Visium_spatial.rds"
spatial <- readRDS(rds_file)
spatial@meta.data$spot_name <- rownames(spatial@meta.data)
spatial <- subset(spatial, spot_name != "ACTGTAGCACTTTGGA-1")
spatial <- subset(spatial, spot_name != "CCCGCCATGCTCCCGT-1")

for (selected_type in cell_type_vector) {
  # paste() with sep = "_" yields "<dataset>_full_RCTD_<type>_.rds" (note the
  # underscore before ".rds"); this matches the names the files were saved
  # under.
  RCTD_file <- paste(dataset_name, "full_RCTD", selected_type, ".rds", sep = "_")
  myRCTD_deleted <- readRDS(RCTD_file)

  # Normalize the cell type proportions to sum to 1, then hide weights < 0.1.
  normalized_weights_deleted <- normalize_weights(myRCTD_deleted@results$weights)
  matrix_result_deleted <- as(normalized_weights_deleted, "matrix")
  mat <- matrix_result_deleted
  ct <- colnames(mat)
  mat[mat < 0.1] <- 0

  # Fresh random colour sample per run, interpolated to one colour per cell
  # type; everything except the GC cell types is then painted dark blue.
  col <- sample(col_vec, n)
  pal <- colorRampPalette(col)(length(ct))
  names(pal) <- ct
  pal_back <- pal
  pal <- plot_n_cell_type(pal_back, col_vec, 6, cell_type_vector)

  # `cell_types` is taken from `mat`; the previous `colnames(y)` referred to
  # an object `y` that this script never defines.
  p1 <- plotSpatialScatterpie(
    x = spatial,
    y = mat,
    cell_types = colnames(mat),
    img = FALSE,
    scatterpie_alpha = 1,
    pie_scale = 0.4
  ) +
    scale_fill_manual(
      values = pal,
      breaks = names(pal)
    ) +
    scale_y_reverse()
  print(p1)
}

# Evaluate the area under the precision-recall curve (AUPR) separately for
# each GC cell type, then treat all types equally and average over the types
# that remain in the reference: five when one type was deleted, all six for
# the complete reference ("no").
library(ROCR)
library(ggplot2)

gc_cell_types <- c(
  "B_GC_DZ", "B_GC_LZ", "B_GC_prePB", "T_CD4+_TfH_GC",
  "B_Cycling", "FDC"
)
reference_settings <- c(gc_cell_types, "no")

average_PR_curve.area <- rep(0, length(reference_settings))
names(average_PR_curve.area) <- reference_settings

# Recode the manual annotation once: empty -> 0 (not GC), "GC" -> 1.
manual_GC_annotation[manual_GC_annotation$GC == "", "GC"] <- 0
manual_GC_annotation[manual_GC_annotation$GC == "GC", "GC"] <- 1
# Leave out the two spots that have no RCTD weights.
excluded_spots <- c("ACTGTAGCACTTTGGA-1", "CCCGCCATGCTCCCGT-1")
manual_GC_label <- manual_GC_annotation[!(manual_GC_annotation$Barcode %in% excluded_spots), ]

# One row per (remaining type, deleted type) with its AUPR; bound at the end.
score_rows <- list()

for (deleted_type in reference_settings) {
  if (deleted_type == "no") {
    RCTD_file <- paste(dataset_name, "full_RCTD.rds", sep = "_")
  } else {
    RCTD_file <- paste(dataset_name, "full_RCTD", deleted_type, ".rds", sep = "_")
  }
  rds_file <- RCTD_file
  myRCTD_deleted <- readRDS(rds_file)
  # Normalize the cell type proportions to sum to 1.
  normalized_weights_deleted <- normalize_weights(myRCTD_deleted@results$weights)
  matrix_result_deleted <- as(normalized_weights_deleted, "matrix")

  remaining_types <- setdiff(gc_cell_types, deleted_type)
  sum_PR_curve.area <- 0
  curve_rows <- vector("list", length(remaining_types))
  names(curve_rows) <- remaining_types

  # Evaluate the AUPR separately for each remaining cell type.
  for (remain_type in remaining_types) {
    # Predicted proportion of `remain_type` in each spot, joined with the
    # manual GC label of that spot.
    proportions <- data.frame(
      Barcode = rownames(matrix_result_deleted),
      proportion = matrix_result_deleted[, remain_type]
    )
    proportions_label <- merge(proportions, manual_GC_label, by = "Barcode")

    predicted_positive <- sum(proportions_label$proportion > 0)
    actual_positive <- sum(proportions_label$GC > 0)

    if (remain_type != "FDC") {
      print(paste(
        "deleted_type:", deleted_type, "remain_type:", remain_type,
        "predicted_positive:", predicted_positive,
        "actual_positive:", actual_positive
      ))

      # Distribution of the non-zero predicted proportions, split by label.
      p3 <- ggplot(proportions_label[proportions_label$proportion > 0, ], aes(x = proportion)) +
        geom_histogram(color = "darkblue", fill = "lightblue") +
        facet_grid(. ~ GC) +
        labs(
          title = paste("deleted_type:", deleted_type, "remain_type:", remain_type),
          subtitle = paste(
            "predicted_positive:", predicted_positive,
            "actual_positive:", actual_positive
          )
        )
      print(p3)
    }

    # Score = predicted proportion, label = manual GC annotation. Columns are
    # addressed by name so extra annotation columns cannot shift the label.
    pred <- prediction(proportions_label$proportion, proportions_label$GC)

    PR_curve.area <- performance(pred, measure = "aucpr")
    area_value <- PR_curve.area@y.values[[1]]
    sum_PR_curve.area <- area_value + sum_PR_curve.area

    score_rows[[length(score_rows) + 1]] <- data.frame(
      cell_type = remain_type, deleted_type = deleted_type,
      value = area_value
    )

    PR_curve <- performance(pred, measure = "prec", x.measure = "rec")
    curve_rows[[remain_type]] <- data.frame(
      cell_type = remain_type,
      recall = c(PR_curve@x.values[[1]]),
      precision = c(PR_curve@y.values[[1]])
    )
  }

  # Precision-recall curves of all remaining types for this reference.
  df_all <- do.call(rbind, unname(curve_rows))
  df_all$cell_type <- as.factor(df_all$cell_type)
  plt <- ggplot(df_all, aes(x = recall, y = precision, color = cell_type)) +
    geom_line() +
    ggtitle(paste(deleted_type, "cell type deleted in reference"))
  print(plt)

  # Average over the types actually evaluated. Dividing by a fixed 5 inflated
  # the baseline ("no"), where all six types are summed.
  average_PR_curve.area[deleted_type] <- sum_PR_curve.area / length(remaining_types)
}

df_score <- do.call(rbind, score_rows)
## [1] "deleted_type: B_GC_DZ remain_type: B_GC_LZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_DZ remain_type: B_GC_prePB predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_DZ remain_type: T_CD4+_TfH_GC predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_DZ remain_type: B_Cycling predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## Warning: Removed 5 row(s) containing missing values (geom_path).

## [1] "deleted_type: B_GC_LZ remain_type: B_GC_DZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_LZ remain_type: B_GC_prePB predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_LZ remain_type: T_CD4+_TfH_GC predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_LZ remain_type: B_Cycling predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## Warning: Removed 5 row(s) containing missing values (geom_path).

## [1] "deleted_type: B_GC_prePB remain_type: B_GC_DZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_prePB remain_type: B_GC_LZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_prePB remain_type: T_CD4+_TfH_GC predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_GC_prePB remain_type: B_Cycling predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## Warning: Removed 5 row(s) containing missing values (geom_path).

## [1] "deleted_type: T_CD4+_TfH_GC remain_type: B_GC_DZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: T_CD4+_TfH_GC remain_type: B_GC_LZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: T_CD4+_TfH_GC remain_type: B_GC_prePB predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: T_CD4+_TfH_GC remain_type: B_Cycling predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## Warning: Removed 5 row(s) containing missing values (geom_path).

## [1] "deleted_type: B_Cycling remain_type: B_GC_DZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_Cycling remain_type: B_GC_LZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_Cycling remain_type: B_GC_prePB predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: B_Cycling remain_type: T_CD4+_TfH_GC predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## Warning: Removed 5 row(s) containing missing values (geom_path).

## [1] "deleted_type: FDC remain_type: B_GC_DZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: FDC remain_type: B_GC_LZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: FDC remain_type: B_GC_prePB predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: FDC remain_type: T_CD4+_TfH_GC predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: FDC remain_type: B_Cycling predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## Warning: Removed 5 row(s) containing missing values (geom_path).

## [1] "deleted_type: no remain_type: B_GC_DZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: no remain_type: B_GC_LZ predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: no remain_type: B_GC_prePB predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: no remain_type: T_CD4+_TfH_GC predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## [1] "deleted_type: no remain_type: B_Cycling predicted_positive: 4033 actual_positive: 378"
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

## Warning: Removed 6 row(s) containing missing values (geom_path).

# Bar chart of the mean AUPR per deleted cell type, with the complete
# reference ("no" type deleted) drawn as a dashed red baseline.
df <- data.frame(
  deleted_cell_type = names(average_PR_curve.area),
  average_area_PR_curve = average_PR_curve.area,
  similarity = "similar"
)
# B_Cycling and FDC are flagged as the "distinct" types; the rest are "similar".
df[df$deleted_cell_type %in% c("B_Cycling", "FDC"), "similarity"] <- "distinct"

# Pull the baseline value out (rows are named by cell type), then drop its row.
ref_line <- df["no", 2]
df <- df[df$deleted_cell_type != "no", ]

p <- ggplot(data = df, aes(x = deleted_cell_type, y = average_area_PR_curve)) +
  geom_bar(stat = "identity") +
  geom_hline(yintercept = ref_line, linetype = "dashed", color = "red")
p

# ---- RMSD between predictions from the complete and the reduced references ----
rmsd_summary <- rep(0, 6)
names(rmsd_summary) <- c(
  "B_GC_DZ", "B_GC_LZ", "B_GC_prePB", "T_CD4+_TfH_GC",
  "B_Cycling", "FDC"
)

for (selected_type in names(rmsd_summary)) {
  # Predictions from the run whose reference lacked `selected_type`.
  RCTD_file <- paste(dataset_name, "full_RCTD", selected_type, ".rds", sep = "_")
  rds_file <- RCTD_file
  myRCTD_deleted <- readRDS(rds_file)
  # Normalize the cell type proportions to sum to 1.
  normalized_weights_deleted <- normalize_weights(myRCTD_deleted@results$weights)
  matrix_result_deleted <- as(normalized_weights_deleted, "matrix")

  # Predictions from the complete reference, without the deleted type's column
  # (the remaining columns are not re-normalised).
  keep_column <- colnames(norm_weights_matrix) != selected_type
  matrix_result_complete <- norm_weights_matrix[, keep_column, drop = FALSE]

  # Align both matrices by spot and cell-type NAME before subtracting, and
  # stop if they do not describe the same spots and types: a positional
  # subtraction would silently compare mismatched entries.
  comparable <- !is.null(rownames(matrix_result_complete)) &&
    !is.null(colnames(matrix_result_complete)) &&
    setequal(rownames(matrix_result_complete), rownames(matrix_result_deleted)) &&
    setequal(colnames(matrix_result_complete), colnames(matrix_result_deleted))
  if (!comparable) {
    stop(
      "Spots or cell types differ between the complete run and the run without ",
      selected_type,
      call. = FALSE
    )
  }
  matrix_result_deleted <- matrix_result_deleted[
    rownames(matrix_result_complete),
    colnames(matrix_result_complete),
    drop = FALSE
  ]

  # For each spot, the root-mean-square deviation between the proportions
  # predicted with the two references; the per-spot values are then summed.
  matrix_substrated <- matrix_result_complete - matrix_result_deleted
  rmsd_per_spot <- sqrt(rowSums(matrix_substrated^2) / ncol(matrix_substrated))
  rmsd_summary[selected_type] <- sum(rmsd_per_spot)
}


df <- data.frame(
  deleted_cell_type = names(rmsd_summary),
  RMSD = rmsd_summary
)
p <- ggplot(data = df, aes(x = deleted_cell_type, y = RMSD)) +
  geom_bar(stat = "identity")
p